# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
from numpy import array, float32


# Pasted from the test output (see export_back_compat_test_util.py module docstring)
data_2024_05_02 = dict(
    testdata_version=1,
    platform='cuda',
    custom_call_targets=['__gpu$xla.gpu.triton'],
    serialized_date=datetime.date(2024, 5, 2),
    inputs=(array([0., 1., 2., 3., 4., 5., 6., 7.], dtype=float32),),
    expected_outputs=(array([1., 2., 3., 4., 5., 6., 7., 8.], dtype=float32),),
    mlir_module_text=r"""
#loc1 = loc("x")
#loc2 = loc("third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py":43:13)
#loc3 = loc("jit(func)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=wrapped keep_unused=False inline=False]"(#loc2))
module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<8xf32> {mhlo.layout_mode = "default"} loc("x")) -> (tensor<8xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = call @wrapped(%arg0) : (tensor<8xf32>) -> tensor<8xf32> loc(#loc3)
    return %0 : tensor<8xf32> loc(#loc)
  } loc(#loc)
  func.func private @wrapped(%arg0: tensor<8xf32> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=wrapped keep_unused=False inline=False]"(#loc2))) -> (tensor<8xf32> {mhlo.layout_mode = "default"}) {
    %0 = stablehlo.custom_call @__gpu$xla.gpu.triton(%arg0) {mhlo.backend_config = {debug = false, grid_x = 8 : i32, grid_y = 1 : i32, grid_z = 1 : i32, ir = "ML\EFR\0DMLIRgooglex-trunk\00\01-\07\01\05\09\17\01\03\0F\03\0D\13\17\1B\1F#'\05\09+/37\03O1\0B\01-\07\0F\0F\0F\0F\13\13\13\0B\0F\0F\13\0B\0F\0B\0B\0B\0B\1F\0B\0B\13\05\05YY\01\09\0F\07\17\0B\03\035\02\16\02\1F\11\01\05\1D)+\1D#\0F\11\01\01#\01\01\01\03\03\19\1B\17\11U'\05\1D\11\07\00\1D'\0F\01\05\0D\0D\05\1F\11\01\81\0D\05\05!\05#\05%\13\03\10\00\00\E0\0F\05'\05)\17\11U\11#arith.overflow<none>\00#arith.fastmath<none>\00\01\02\02\0B\05\05\09\09\01\01\09!tt.ptr<f32>\00\04\D2\02\05\01P\01\01\07\04\AE\02\03\01\05\07P\01\03\07\04\82\02\03+W\05\11\11\00\09B\01\05\03\01\0FB\01\07\03\01\11F\01\09\03\01\05\05\07\0FB\01\07\03\01\11F\01\09\03\01\05\05\0B\0FB\07\05\03\01\13F\07\09\03\01\05\0F\09\0FB\07\07\03\01\11F\07\09\03\01\05\11\13\03\06\07\03\09\05\01\15\05F\07\0B\03\03\03\17\0FB\15\0D\03\03\15F\15\0F\03\03\05\19\1B\0FB\05\05\03\01\13F\05\09\03\01\05\1F\0D\0FB\05\07\03\01\11F\05\09\03\01\05!#\03\06\05\03\09\05\03%\05F\05\0B\03\03\03'\0BD\05\11\05'\1D\0D\00\01\06\03\01\05\01\00\F2\05+\A5\0B\A3\0F\11!\85\0B\0B\0B\13\0F\0D\1F\0B\0B\0F\0F\0D\07\11builtin\00tt\00arith\00module\00addptr\00load\00func\00get_program_id\00store\00return\00constant\00muli\00addi\00addf\00third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py\00tt.divisibility\00add_one\00public\00/get[tree=PyTreeDef((CustomNode(NDIndexer[(PyTreeDef((*,)), (1,), ())], [*]),))]\00/add\00/swap[tree=PyTreeDef((CustomNode(NDIndexer[(PyTreeDef((*,)), (1,), ())], [*]),))]\00\08C\13\05\01\01\0B/\1D\01\1FC\03\09\03\03\03[\11\17\07\07'\01\07\01\03\03%\03_\07\17\07\07", name = "add_one", num_stages = 3 : i32, num_warps = 4 : i32}, operand_layouts = [dense<0> : tensor<1xindex>], result_layouts = [dense<0> : tensor<1xindex>]} : (tensor<8xf32>) -> tensor<8xf32> loc(#loc4)
    return %0 : tensor<8xf32> loc(#loc3)
  } loc(#loc3)
} loc(#loc)
#loc = loc(unknown)
#loc4 = loc("jit(func)/jit(main)/jit(wrapped)/pallas_call[name=add_one which_linear=(False,) in_shapes=(ShapeDtypeStruct(shape=(8,), dtype=float32),) out_shapes=(ShapeDtypeStruct(shape=(8,), dtype=float32),) debug=False interpret=False grid_mapping=GridMapping(grid=(8,), block_mappings=(BlockMapping(block_shape=(1,), index_map_jaxpr={ lambda ; a:i32[]. let  in (a,) }, indexing_mode=<jax._src.pallas.core.Blocked object at 0x72127ace4e90>), BlockMapping(block_shape=(1,), index_map_jaxpr={ lambda ; a:i32[]. let  in (a,) }, indexing_mode=<jax._src.pallas.core.Blocked object at 0x72127ace4e90>)), mapped_dims=(), num_index_operands=0, num_scratch_operands=0) input_output_aliases=() compiler_params={}]"(#loc2))
""",
    mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01\x19\x05\x01\x03\x01\x03\x05\x03\t\x07\t\x0b\r\x03\xaf\x8b\x11\x01G\x07\x0f\x0b\x0f\x0b\x0b\x0b\x0b\x13+\x0b\x0f\x0b\x0b\x0b33\x0b\x0bS\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0b\x0f\x0b\x13\x0b\x03E\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0f\x13\x0f\x1b\x0b\x0b\x0b\x0b\x0b\x0bK\x0b\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0b\x0f/\x01\x05\x0b\x0f\x03\r\x13\x07\x17\x07\x13\x07\x02\xee\x03\x1f\x1d#\x11\x05\x0f\x11\x03\x05\x05\x11\x05\x13\x05\x15\x05\x17\x17%W\x1b\x03\t\x15\x17\x19\x07\x1b\x07\x05\x1d\x05\x19\x11\x01\x00\x05\x1b\x05\x1d\x05\x1f\x03\x0b\tG\x0bM\r]\x05c\x0fe\x03\x0b\tG\x0bM\rG\x05Q\x0fg\x05!\x05#\x03\x13)i+O-k/S1U3m5Y7S9Y\x05%\x05'\x05)\x05+\x05-\x05/\x051\x053\x055\x1d=\x11\x057\x1dA\x01\x059\x03\x03EQ\x05;\x03\x03[\x1d=\x1d?#\t\x1dA\x1dC\x03\x01\x05\x01\x13\x07\x05\x03\x03\x89\r\x03IK\x03\x03_\r\x05aOIK\x1dE\x1dG\x1dI\x1dK\x0b\x03\x1dM\r\x11oUqsuWwWy{}\x7f\x81\x83\x85\x87\x1dO\x1dQ\x13\x07!\x1dS\x1dU\x1dW\x1dY\x1d[\x1d]\x1d_\x13\x07\r\x1da\x13\x07\x11\x1f\r\x11\x00\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x03!\x0b\x1b\x11\x03\x05\x03\x05\t)\x03\x05\x0f\x13\x04s\x05\x01\x11\x01\x13\x07\x03\x01\t\x03\x11\x01\x1f\x07\x03\x05\x0b\x03\x05?\t\x07\x03C\x03\x05\x03\x01\x05\x04\x01\x03\x03\x03\x11\x03!\x07\x03\x05\x0b\x03\x05\x03\x07\x07;'\x03\x05\x03\x01\x05\x04\x03\x03\x03\x06\x03\x01\x05\x01\x00\x12%c\x15\x17\x11\x0b\xfe\x0c\x07\x0f\x0f\x0f\r+\x11\x0f\x0b!\x11\x03\x11#\x0f\x05\xd2\n\x1f/!)!)#\x1f\x19\x85j\x03\x13%)9\x1f\x15\x1d\x15\x13\x11\x1f\x15\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00func_v1\x00return_v1\x00custom_call_v1\x00call_v1\x00sym_name\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00jit(func)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=wrapped keep_unused=False inline=False]\x00third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00mhlo.backend_config\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit(func)/jit(main)/jit(wrapped)/pallas_call[name=add_one which_linear=(False,) in_shapes=(ShapeDtypeStruct(shape=(8,), dtype=float32),) out_shapes=(ShapeDtypeStruct(shape=(8,), dtype=float32),) debug=False interpret=False grid_mapping=GridMapping(grid=(8,), block_mappings=(BlockMapping(block_shape=(1,), index_map_jaxpr={ lambda ; a:i32[]. let  in (a,) }, indexing_mode=<jax._src.pallas.core.Blocked object at 0x72127ace4e90>), BlockMapping(block_shape=(1,), index_map_jaxpr={ lambda ; a:i32[]. let  in (a,) }, indexing_mode=<jax._src.pallas.core.Blocked object at 0x72127ace4e90>)), mapped_dims=(), num_index_operands=0, num_scratch_operands=0) input_output_aliases=() compiler_params={}]\x00x\x00callee\x00mhlo.layout_mode\x00default\x00\x00wrapped\x00jax.result_info\x00main\x00public\x00private\x00__gpu$xla.gpu.triton\x00debug\x00grid_x\x00grid_y\x00grid_z\x00ir\x00ML\xefR\rMLIRgooglex-trunk\x00\x01-\x07\x01\x05\t\x17\x01\x03\x0f\x03\r\x13\x17\x1b\x1f#'\x05\t+/37\x03O1\x0b\x01-\x07\x0f\x0f\x0f\x0f\x13\x13\x13\x0b\x0f\x0f\x13\x0b\x0f\x0b\x0b\x0b\x0b\x1f\x0b\x0b\x13\x05\x05YY\x01\t\x0f\x07\x17\x0b\x03\x035\x02\x16\x02\x1f\x11\x01\x05\x1d)+\x1d#\x0f\x11\x01\x01#\x01\x01\x01\x03\x03\x19\x1b\x17\x11U'\x05\x1d\x11\x07\x00\x1d'\x0f\x01\x05\r\r\x05\x1f\x11\x01\x81\r\x05\x05!\x05#\x05%\x13\x03\x10\x00\x00\xe0\x0f\x05'\x05)\x17\x11U\x11#arith.overflow<none>\x00#arith.fastmath<none>\x00\x01\x02\x02\x0b\x05\x05\t\t\x01\x01\t!tt.ptr<f32>\x00\x04\xd2\x02\x05\x01P\x01\x01\x07\x04\xae\x02\x03\x01\x05\x07P\x01\x03\x07\x04\x82\x02\x03+W\x05\x11\x11\x00\tB\x01\x05\x03\x01\x0fB\x01\x07\x03\x01\x11F\x01\t\x03\x01\x05\x05\x07\x0fB\x01\x07\x03\x01\x11F\x01\t\x03\x01\x05\x05\x0b\x0fB\x07\x05\x03\x01\x13F\x07\t\x03\x01\x05\x0f\t\x0fB\x07\x07\x03\x01\x11F\x07\t\x03\x01\x05\x11\x13\x03\x06\x07\x03\t\x05\x01\x15\x05F\x07\x0b\x03\x03\x03\x17\x0fB\x15\r\x03\x03\x15F\x15\x0f\x03\x03\x05\x19\x1b\x0fB\x05\x05\x03\x01\x13F\x05\t\x03\x01\x05\x1f\r\x0fB\x05\x07\x03\x01\x11F\x05\t\x03\x01\x05!#\x03\x06\x05\x03\t\x05\x03%\x05F\x05\x0b\x03\x03\x03'\x0bD\x05\x11\x05'\x1d\r\x00\x01\x06\x03\x01\x05\x01\x00\xf2\x05+\xa5\x0b\xa3\x0f\x11!\x85\x0b\x0b\x0b\x13\x0f\r\x1f\x0b\x0b\x0f\x0f\r\x07\x11builtin\x00tt\x00arith\x00module\x00addptr\x00load\x00func\x00get_program_id\x00store\x00return\x00constant\x00muli\x00addi\x00addf\x00third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py\x00tt.divisibility\x00add_one\x00public\x00/get[tree=PyTreeDef((CustomNode(NDIndexer[(PyTreeDef((*,)), (1,), ())], [*]),))]\x00/add\x00/swap[tree=PyTreeDef((CustomNode(NDIndexer[(PyTreeDef((*,)), (1,), ())], [*]),))]\x00\x08C\x13\x05\x01\x01\x0b/\x1d\x01\x1fC\x03\t\x03\x03\x03[\x11\x17\x07\x07'\x01\x07\x01\x03\x03%\x03_\x07\x17\x07\x07\x00name\x00add_one\x00num_stages\x00num_warps\x00",
    xla_call_module_version=9,
    nr_devices=1,
)  # End paste
